from __future__ import print_function
import torch
from torch import nn, optim
from torchvision.utils import save_image
import os
from tqdm import tqdm

from vae_models.vae_builder import VAEBuilder

# Module-level training state shared by train(), test() and the __main__ loop.
# Everything is obtained through the project's VAEBuilder.
builder = VAEBuilder()
# `args` supplies the run settings this script reads (lr, log_interval,
# batch_size, epochs, log_dir); `device` is where batches are moved with .to().
args, device = builder.get_arguments()
# Loaders iterated by train() and test() respectively; each yields (data, _) pairs.
train_loader, test_loader = builder.get_dataset()
# Receives the add_scalar() calls made in train() and test().
summary_writer = builder.get_summary_writer()
model = builder.build_vae(device, args)
# Adam over all model parameters, learning rate taken from args.lr.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Disabled: presumably restores model weights from an initial checkpoint
# (NOTE(review): confirm against VAEBuilder before re-enabling).
# builder.load_initial_checkpoint(model)


def train(epoch):
    """Run one optimisation pass over ``train_loader`` and log epoch statistics.

    For every batch the model is evaluated, ``losses.total`` is
    back-propagated and the optimizer takes a step. Progress is printed every
    ``args.log_interval`` batches.

    After the pass, the accumulated ``losses.elbo`` divided by
    ``len(train_loader.dataset)`` is printed and written as ``train/elbo``.
    All other scalars (rec, KLD, sigma statistics) come from the *last* batch
    of the epoch only.

    Args:
        epoch: Epoch number; shown in the console output and used as the
            global step of every scalar written to ``summary_writer``.
    """
    model.train()
    progress_fmt = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tMSE: {:.6f}\tKL: {:.6f}\tlog_sigma: {:f}'
    elbo_total = 0
    for step, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        losses = model.loss_function(recon, batch, mu, logvar)

        # Debugging hook: stop in pdb as soon as the objective becomes NaN.
        if torch.isnan(losses.total):
            import pdb
            pdb.set_trace()

        losses.total.backward()

        # if torch.isnan(model.fc1_sigma.weight.grad).any():
        #     import pdb; pdb.set_trace()

        elbo_total += losses.elbo.item()
        optimizer.step()

        if step % args.log_interval == 0:
            batch_len = len(batch)
            print(progress_fmt.format(
                epoch,
                step * batch_len,
                len(train_loader.dataset),
                100. * step / len(train_loader),
                losses.rec.item() / batch_len,
                losses.KLD.item() / batch_len,
                recon.sigma.mean()))

        # if step == 200:
        #     break

    mean_elbo = elbo_total / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_elbo))

    log_scalar = summary_writer.add_scalar
    log_scalar('train/elbo', mean_elbo, epoch)

    # Everything below reports the final batch of the epoch, not an average.
    last_batch_len = len(batch)
    log_scalar('train/rec', losses.rec.item() / last_batch_len, epoch)
    log_scalar('train/kld', losses.KLD.item() / last_batch_len, epoch)
    # Optional entries: only logged when the loss / reconstruction containers
    # actually carry them.
    if 'sigma_est_rec' in losses:
        log_scalar('train/sigma_est_rec', losses.sigma_est_rec.item() / last_batch_len, epoch)
    if 'sigma_est' in recon:
        log_scalar('train/log_sigma_est', recon.sigma_est.mean(), epoch)
    sigma = recon.sigma
    log_scalar('train/log_sigma', sigma.mean(), epoch)
    log_scalar('train/beta', 2 * sigma.exp().mean() ** 2, epoch)


def test(epoch):
    """Evaluate the model on ``test_loader`` and log the mean test ELBO.

    Every batch of the loader is scored with
    ``model.loss_function(..., phase='test')`` under ``torch.no_grad()``. The
    accumulated ``elbo`` values are divided by the number of samples actually
    processed, so the reported figure is a per-sample average regardless of
    the loader's batch size. (This assumes ``elbo`` is a sum over the batch,
    as the normalisation by dataset size in ``train`` suggests — TODO confirm
    against ``model.loss_function``.)

    For the first batch only, up to eight inputs and their reconstructions
    (``recon_batch.mle``) are saved as one image grid, inputs in the first
    row, at ``vae_logs/<log_dir>/reconstruction_<epoch>.png``.

    Args:
        epoch: Epoch number; used in the image file name and as the global
            step of the ``test/elbo`` scalar.
    """
    model.eval()
    test_loss = 0
    num_seen = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(tqdm(test_loader)):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            # Original note: "Pass the second value from posthoc VAE".
            # NOTE(review): its meaning is not clear from this file — confirm.
            test_loss += model.loss_function(recon_batch, data, mu, logvar, phase='test').elbo.item()
            num_seen += data.size(0)
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape with the real batch size: the first batch is not
                # guaranteed to hold exactly args.batch_size samples.
                reconstruction = recon_batch.mle.view(data.size(0), -1, 28, 28)
                comparison = torch.cat([data[:n], reconstruction[:n]])
                save_image(comparison.cpu(),
                           'vae_logs/{}/reconstruction_{}.png'.format(args.log_dir, str(epoch)), nrow=n)
            # No early exit here: the whole test set is evaluated so that the
            # average below covers all of it rather than a single batch.
    # Normalise by what was actually evaluated; max() guards an empty loader.
    test_loss /= max(num_seen, 1)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    summary_writer.add_scalar('test/elbo', test_loss, epoch)


if __name__ == "__main__":
    # Script entry point. Each epoch: train, evaluate, then save an image of
    # freshly drawn model samples and flush the summary writer.
    num_samples = 64
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = model.sample(num_samples).cpu()
            sample_path = 'vae_logs/{}/sample_{}.png'.format(args.log_dir, str(epoch))
            save_image(sample.view(num_samples, -1, 28, 28), sample_path)
        summary_writer.file_writer.flush()

    # A single checkpoint is written after the last epoch; `epoch` still holds
    # the final loop value here.
    checkpoint_path = 'vae_logs/{}/checkpoint_{}.pt'.format(args.log_dir, str(epoch))
    torch.save(model.state_dict(), checkpoint_path)
